#include "labelling.h"

// Collapse a unit-level adjacency list into a district-level adjacency graph.
//
// pl : district label of every unit, 1-based (shifted to 0-based below).
// nd : number of districts; labels are expected to lie in 1..nd.
// l  : adjacency list; l[u] holds the indices of the units adjacent to unit u,
//      used directly to index `pl` (so presumably 0-based — confirm with caller).
//
// Returns a Graph with `nd` rows; row d lists, in increasing order, every other
// district that contains a unit adjacent to some unit of d. A district is never
// listed as its own neighbour.
// [[Rcpp::export]]
Graph district_graph(IntegerVector pl, int nd, const List &l) {
    // Dense nd x nd adjacency flags; filled row-by-row so nd <= 0 yields no rows.
    std::vector<std::vector<bool>> adjacent;
    for (int d = 0; d < nd; d++) {
        adjacent.push_back(std::vector<bool>(nd, false));
    }

    const int n_units = l.size();
    for (int u = 0; u < n_units; u++) {
        IntegerVector unit_nbors = static_cast<IntegerVector>(l[u]);
        const int d_u = pl[u] - 1;
        const int n_unit_nbors = unit_nbors.size();
        for (int k = 0; k < n_unit_nbors; k++) {
            const int d_v = pl[unit_nbors[k]] - 1;
            if (d_v == d_u) continue;  // edges inside one district are ignored
            adjacent[d_u][d_v] = true;
        }
    }

    // Convert the flag matrix into sorted adjacency lists.
    Graph out;
    for (int d = 0; d < nd; d++) {
        std::vector<int> row;
        for (int e = 0; e < nd; e++) {
            if (adjacent[d][e]) row.push_back(e);
        }
        out.push_back(row);
    }
    return out;
}


// Memoized recursive helper for count_labelings_cpp.
//
// Adds vertex `add` to the set `A_in` (which holds `n_in` vertices before the
// call) and returns the log of the number of ways to extend the ordering.
// Each subsequently added vertex must be adjacent to at least one vertex
// already in the set.
//
// A_in  : membership flags, one per vertex; temporarily modified but restored
//         to its original state before returning.
// n_in  : number of `true` entries in A_in on entry.
// add   : vertex to add at this step.
// g     : adjacency lists, one per vertex.
// memos : cache mapping a vertex set (after adding) to its log-count.
//
// Once at most one vertex would remain outside the set, the function returns
// 0.0 (log 1) without checking that the last vertex is adjacent to the set, so
// the result is only a true count for connected graphs.
double counter_helper(std::vector<bool> &A_in, int n_in, int add, const Graph &g, 
                      std::map<std::vector<bool>, double> &memos) {
    const int n = A_in.size();
    if (n_in >= n - 2) return 0.0;

    A_in[add] = true;
    n_in++;

    auto search = memos.find(A_in);
    if (search != memos.end()) {
        A_in[add] = false;
        return search->second;
    }

    std::vector<double> xchild(n);
    std::vector<bool> skip(n, true);   // true => vertex i cannot be added next
    double max_x = 0.0;                // shift for the log-sum-exp below
    for (int i = 0; i < n; i++) {
        if (A_in[i]) continue;

        // Bind by reference: copying the adjacency list here happened once per
        // vertex at every level of the recursion.
        const auto &nbors = g[i];
        bool touches = false;
        for (int nbor : nbors) {
            if (A_in[nbor]) {
                touches = true;
                break;
            }
        }
        if (!touches) continue;
        skip[i] = false;

        xchild[i] = counter_helper(A_in, n_in, i, g, memos);
        if (xchild[i] > max_x)
            max_x = xchild[i];
    }

    // log(sum_i exp(xchild[i])) computed stably around max_x.
    double accuml = 0.0;
    for (int i = 0; i < n; i++) {
        if (!skip[i]) {
            accuml += std::exp(xchild[i] - max_x);
        }
    }

    const double out = std::log(accuml) + max_x;
    memos.emplace(A_in, out);  // emplace copies the key; no separate copy needed
    A_in[add] = false;
    return out;
}

// Log of the number of vertex orderings of the graph given by adjacency list
// `l` in which every vertex after the first is adjacent to some earlier vertex.
// The final vertex's adjacency is not re-checked by counter_helper, so this
// assumes a connected graph.
//
// Returns a length-1 numeric vector. Sums the per-start-vertex log-counts from
// counter_helper with a log-sum-exp; one memo table is shared across starts.
// [[Rcpp::export]]
NumericVector count_labelings_cpp(const List &l) {
    Graph g = list_to_graph(l);
    const int n = g.size();
    std::map<std::vector<bool>, double> memos;
    std::vector<bool> in_set(n, false);

    std::vector<double> log_counts(n);
    double shift = 0.0;
    for (int start = 0; start < n; start++) {
        log_counts[start] = counter_helper(in_set, 0, start, g, memos);
        if (log_counts[start] > shift) {
            shift = log_counts[start];
        }
    }

    double total = 0.0;
    for (double lc : log_counts) {
        total += std::exp(lc - shift);
    }

    NumericVector out(1);
    out[0] = std::log(total) + shift;
    return out;
}


// Check whether `ordering` can be built up from its last entry backwards so
// that every vertex, when added, is adjacent to at least one vertex already
// added.
//
// g        : adjacency lists, one per vertex.
// ordering : 0-based vertex indices; expected to be a permutation of 0..n-1
//            where n = g.size() (values are not range-checked here).
//
// Returns true for an empty graph. Calls stop() if `ordering` is shorter than
// the number of vertices.
// NOTE(review): `g` is taken by value, which copies the whole graph per call;
// the signature is left unchanged in case a header declares it this way.
bool is_valid_sequential(Graph g, IntegerVector ordering) {
    const int n = g.size();
    // Previously ordering[n-1] was read even when n == 0 (index -1).
    if (n == 0) return true;
    if (ordering.size() < n) stop("`ordering` must have one entry per vertex.");

    std::vector<bool> visited(n, false);
    visited[ordering[n - 1]] = true;
    for (int i = n - 2; i >= 0; i--) {
        const int dist = ordering[i];

        bool ok = false;
        for (int nbor : g[dist]) {  // iterate in place instead of copying
            if (visited[nbor]) {
                ok = true;
                break;
            }
        }

        if (!ok) return false;
        visited[dist] = true;
    }

    return true;
}

// For each of `n_try` uniformly random permutations of the districts in plan
// `pl`, report whether the permutation passes is_valid_sequential on the
// district adjacency graph built by district_graph().
//
// pl    : 1-based district label per unit; the district count is max(pl).
// l     : unit adjacency list.
// n_try : number of random permutations to test.
// [[Rcpp::export]]
LogicalVector check_valid(IntegerVector pl, const List &l, int n_try=1) {
    const int nd = max(pl);
    const Graph dist_g = district_graph(pl, nd, l);

    LogicalVector valid(n_try);
    for (int t = 0; t < n_try; t++) {
        // 0-based permutation of 0..nd-1 (last argument: one_based = false).
        IntegerVector perm = sample(nd, nd, false, R_NilValue, false);
        valid[t] = is_valid_sequential(dist_g, perm);
    }
    return valid;
}

// Like check_valid, but `l` is already the adjacency list of the graph whose
// vertices are permuted. For each of `n_try` random 0-based permutations of its
// vertices, reports whether is_valid_sequential accepts it.
// [[Rcpp::export]]
LogicalVector check_valid_gr(const List &l, int n_try=1) {
    const Graph g = list_to_graph(l);
    const int n_vtx = l.size();

    LogicalVector valid(n_try);
    for (int t = 0; t < n_try; t++) {
        valid[t] = is_valid_sequential(g, sample(n_vtx, n_vtx, false, R_NilValue, false));
    }
    return valid;
}


// Sample `n` random vertex orderings of the graph described by `l` and return
// the log-probability of each sampled ordering under the sampling scheme.
//
// Each ordering is grown one vertex at a time:
//   - the first vertex is drawn with probability weights[v] / sum(weights);
//   - each later vertex is drawn from the candidate set (unvisited vertices
//     adjacent to some visited vertex) with probability
//     weights[v] / (total candidate weight).
// Only the log-probabilities are returned; the orderings are discarded.
//
// l       : adjacency list, converted with list_to_graph().
// weights : one weight per vertex; presumably positive (not checked here).
// n       : number of orderings to draw.
//
// Calls stop() if `weights` has the wrong length, or if a draw runs out of
// candidates before every vertex is visited (i.e. the graph is disconnected).
// [[Rcpp::export]]
NumericVector random_labelings(const List &l, NumericVector weights, int n) {
    Graph g = list_to_graph(l);
    const int V = g.size();
    // The old message named `lp`/`n`, neither of which is what is being checked.
    if (weights.size() != V) stop("`weights` must have one entry per vertex.");
    NumericVector lp(n);
    const double tot_wgt = sum(weights);

    // Per-draw state, allocated once and reset at the start of each draw.
    std::vector<bool> candidate(V, false);
    std::vector<bool> visited(V, false);
    double cand_wgt = 0;  // total weight of the vertices flagged in `candidate`

    // Mark `v` visited and add each neighbour that is neither visited nor
    // already a candidate. Checking both flags for every vertex (including the
    // first) keeps `cand_wgt` consistent even with duplicate neighbours or
    // self-loops in the adjacency list.
    auto absorb = [&](int v) {
        visited.at(v) = true;
        for (int nb : g.at(v)) {
            if (!visited.at(nb) && !candidate.at(nb)) {
                candidate.at(nb) = true;
                cand_wgt += weights.at(nb);
            }
        }
    };

    for (int i = 0; i < n; i++) {
        candidate.assign(V, false);
        visited.assign(V, false);
        cand_wgt = 0;

        // First vertex: inverse-CDF draw over all weights. If rounding leaves
        // the running sum short of the target, the last vertex is used.
        double target = tot_wgt * unif(generator);
        double accuml = 0;
        int vtx = 0;
        for (; vtx < V - 1; vtx++) {
            accuml += weights.at(vtx);
            if (accuml >= target) break;
        }
        lp[i] = std::log(weights.at(vtx)) - std::log(tot_wgt);
        absorb(vtx);

        for (int j = 1; j < V; j++) {
            target = cand_wgt * unif(generator);
            accuml = 0;
            // Stays -1 only if no candidate exists; previously this case read an
            // uninitialized index.
            int next = -1;
            for (int k = 0; k < V; k++) {
                if (!candidate.at(k)) continue;
                next = k;  // falls back to the last candidate under rounding
                accuml += weights.at(k);
                if (accuml >= target) break;
            }
            if (next < 0) stop("Graph must be connected: no candidate vertex left to add.");
            lp[i] += std::log(weights.at(next)) - std::log(cand_wgt);

            candidate.at(next) = false;
            cand_wgt -= weights.at(next);
            absorb(next);
        }
    }

    return lp;
}